"""
 @Author: zzgsg
 @time: 2024/9/28 20:32
"""

# Figure 2. dAMF of different c, h, and s
import os

import h5py
import matplotlib.pyplot as plt
import numpy as np

def get_yticks(view_lim):
    """Return "nice" y-tick positions for the y view interval *view_lim*.

    The tick spacing is chosen from the span ``view_lim[1] - view_lim[0]``:

    * span > 3   -> spacing 1
    * span > 1.5 -> spacing 0.5
    * span > 0.8 -> spacing 0.2

    In these cases the ticks are the multiples of the spacing that lie in
    ``(lower, upper]``. Otherwise, including a non-positive span such as an
    inverted axis, four ticks 0.2 apart are placed around the midpoint
    (rounded down to 0.1). These may fall outside *view_lim*.

    Parameters
    ----------
    view_lim : array-like of two floats
        ``(lower, upper)`` limits, e.g. ``ax.viewLim.intervaly`` or the tuple
        returned by ``ax.get_ylim()``.

    Returns
    -------
    numpy.ndarray
        Float tick positions.
    """
    # Coerce to a float ndarray first. For a plain tuple/list, ``view_lim * 2``
    # would repeat the sequence instead of scaling it and silently give wrong ticks.
    view_lim = np.asarray(view_lim, dtype=float)
    view_length = view_lim[1] - view_lim[0]
    # (minimum span, ticks per unit), from the coarsest to the finest spacing.
    for min_span, per_unit in ((3, 1), (1.5, 2), (0.8, 5)):
        if view_length > min_span:
            scaled = np.floor(view_lim * per_unit)
            # m / per_unit for floor(per_unit*lower) < m <= floor(per_unit*upper).
            return np.arange(scaled[0], scaled[1]) / per_unit + 1 / per_unit
    mid = np.floor(np.mean(view_lim) * 10) / 10
    return np.arange(mid - 0.3, mid + 0.4, 0.2)

# Alternative loader (kept for reference): read the LUT from the NetCDF file via
# the project module ``mapalook_up`` and reorder its axes.
# fname = r'E:\MAPALUTs\O4LUT_360nm.nc'
# LUT, LUT_parameters = mapalook_up.get_lut_nc(fname)
#
# LUT = LUT[:].transpose((0, 1, 5, 2, 3, 4))

# Read every dataset of the HDF5 file into memory. The ``with`` block closes the
# file handle afterwards; previously it stayed open for the rest of the script.
with h5py.File('O4LUT_360nm_damf.h5', 'r') as f1:
    LUT_parameters = {key: f1[key][:] for key in f1.keys()}
# 'LUT' is the table itself. The remaining entries include the grid axes read below.
# The plotting code indexes it as LUT[sza_idx, raa_idx, elev_idx] and treats the
# result as a (c, h, s) cube.
LUT = LUT_parameters.pop('LUT')

# Grid axes of the LUT. The bracketed comments list the values in the current file.
c = LUT_parameters['c']
# [0.   0.04 0.1  0.18 0.29 0.45 0.67 0.93 1.24 1.6  2.4  4.  ]
h = LUT_parameters['h']
# [0.1  0.22 0.34 0.48 0.64 0.9  1.24 1.6  2.   2.5  3.2  4.  ]
s = LUT_parameters['s']
# [0.  0.2 0.4 0.6 0.8 1.  1.1 1.2 1.4 1.6 1.8]
sza = LUT_parameters['sza']
#[ 5 15 25 33 41 48 54 60 65 70 75 80]
raa = LUT_parameters['raa']
#[  0   5  10  20  30  45  65  90 120 150 180]
elev = LUT_parameters['elev']
#[ 1  2  3  4  5  6  8 10 15 30 90]

# One colour per curve. A panel draws up to 12 curves (c and h have 12 grid
# points, s has 11), so 12 colours are needed.
colors = [
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
    '#ffbb78', '#98df8a'
]
# Viewing geometries to plot, as indices into sza / raa / elev. With the grids
# listed above these are sza=33, raa=120 and elev in {1, 30}.
pick_sza = [3]
pick_raa = [-3]
pick_elev = [0, -2]
# 1 -> render the individual per-geometry figure rows.
apart=1
def _plot_dependence_row(cube, x, x_name, curve_values, curve_name,
                         panel_values, panel_name, panel_picks, letters,
                         curve_colors, out_path):
    """Draw a 1x4 row of dAMF line plots from *cube* and save it to *out_path*.

    Parameters
    ----------
    cube : numpy.ndarray
        3-D array laid out as (x, curve, panel). ``cube[:, o, p]`` is the curve
        drawn against *x* for curve index ``o``, in the panel where the third
        quantity sits at index ``p``.
    x : array-like
        Horizontal-axis grid values (length ``cube.shape[0]``).
    x_name : str
        Horizontal-axis label.
    curve_values, curve_name :
        Grid values and name of the curve quantity. Each of the
        ``cube.shape[1]`` curves is labelled ``'<curve_name>=<value>'``.
    panel_values, panel_name :
        Grid values and name of the quantity held fixed in each panel, shown
        as ``'<panel_name>=<value>'`` at the top of the panel.
    panel_picks : sequence of int
        One index into *panel_values* per panel (4 panels; negatives allowed).
    letters : str
        One letter per panel, drawn as ``(a)``, ``(b)``, ... in the top-left corner.
    curve_colors : sequence of str
        Curve ``o`` is drawn in ``curve_colors[o]``.
    out_path : str
        Destination file.
    """
    fig, axs = plt.subplots(1, 4, figsize=(20, 4))
    fig.subplots_adjust(top=0.95, bottom=0.15, right=0.915, left=0.04, hspace=0.12, wspace=0.17)
    axs = axs.flatten()
    for n, ax in enumerate(axs):
        p = panel_picks[n]
        for o in range(cube.shape[1]):
            ax.plot(x, cube[:, o, p], '-v', label='{}={}'.format(curve_name, curve_values[o]),
                    color=curve_colors[o])
        ax.set_xlabel(x_name)
        ax.grid(linestyle='-.', linewidth=1)
        ax.annotate('{}={}'.format(panel_name, panel_values[p]), xy=(0.5, 0.9), xycoords='axes fraction',
                    horizontalalignment='center', verticalalignment='top')
        ax.annotate('({})'.format(letters[n]), xy=(0.05, 0.95), xycoords='axes fraction',
                    horizontalalignment='center', verticalalignment='top')
        # Ticks are chosen after plotting so that they follow the data range.
        ax.set_yticks(get_yticks(ax.viewLim.intervaly))
    axs[0].annotate('dAMF', xy=(-0.15, 0.5), xycoords='axes fraction', horizontalalignment='center',
                    verticalalignment='center', rotation=90)
    # A single shared legend, anchored in figure coordinates in the right-hand
    # margin left free by ``right=0.915`` above.
    axs[3].legend(loc='center left', bbox_to_anchor=(0.92, 0.5), bbox_transform=fig.transFigure, ncol=1)
    fig.savefig(out_path)
    plt.close(fig)


if apart==1:
    order = 'abcdefghijklmnopqrstuvwxyz'
    num = 0  # offset into ``order``; each geometry consumes 12 panel letters
    plt.rcParams['font.size'] = 15
    # savefig does not create missing directories.
    os.makedirs('temp/damf', exist_ok=True)
    for i in pick_sza:
        for j in pick_raa:
            for k in pick_elev:
                geometry = (sza[i], raa[j], elev[k])
                # dAMF for this geometry, indexed as (c, h, s).
                lut = LUT[i, j, k]
                # Row 1: dAMF vs c, one curve per h, panels at fixed s.
                _plot_dependence_row(
                    lut, c, 'c', h, 'h', s, 's', [1, 5, 6, -1],
                    order[num:num + 4], colors,
                    r'temp/damf/damf_changes_with_c_in_different_h_at_sza{}raa{}elev{}.jpg'.format(*geometry))
                # Row 2: dAMF vs h, one curve per s, panels at fixed c.
                # Transposed to (h, s, c), so cube[:, o, p] == lut[p, :, o].
                _plot_dependence_row(
                    np.transpose(lut, (1, 2, 0)), h, 'h', s, 's', c, 'c', [1, 3, 6, -1],
                    order[num + 4:num + 8], colors,
                    r'temp/damf/damf_changes_with_h_in_different_s_at_sza{}raa{}elev{}.jpg'.format(*geometry))
                # Row 3: dAMF vs s, one curve per c, panels at fixed h.
                # Transposed to (s, c, h), so cube[:, o, p] == lut[o, p, :].
                _plot_dependence_row(
                    np.transpose(lut, (2, 0, 1)), s, 's', c, 'c', h, 'h', [0, 5, 10, -1],
                    order[num + 8:num + 12], colors,
                    r'temp/damf/damf_changes_with_s_in_different_c_at_sza{}raa{}elev{}.jpg'.format(*geometry))
                num = num + 12

merged = 1
if merged == 1:
    # Per-geometry row images, in the order the plotting step writes them:
    # c-, h- and s-dependence.
    row_patterns = (
        r'temp/damf/damf_changes_with_c_in_different_h_at_sza{}raa{}elev{}.jpg',
        r'temp/damf/damf_changes_with_h_in_different_s_at_sza{}raa{}elev{}.jpg',
        r'temp/damf/damf_changes_with_s_in_different_c_at_sza{}raa{}elev{}.jpg',
    )
    np_jpg = []
    for i in pick_sza:
        for j in pick_raa:
            for k in pick_elev:
                geometry = (sza[i], raa[j], elev[k])
                for pattern in row_patterns:
                    np_jpg.append(plt.imread(pattern.format(*geometry)))
    # Stack all rows vertically into one composite image. All rows must share
    # the same width and channel count.
    np_jpg = np.concatenate(np_jpg, axis=0)
    plt.imsave(r'temp/damf_changes_with_chs.eps', np_jpg)
    plt.imsave(r'temp/damf_changes_with_chs.png', np_jpg)